# This file is part of the Open Data Cube, see https://opendatacube.org for more information
#
# Copyright (c) 2015-2025 ODC Contributors
# SPDX-License-Identifier: Apache-2.0
"""Dask Distributed Tools"""

import logging
import os
import queue
import threading
from collections.abc import Iterable
from random import randint
from typing import Any

import dask
import toolz
from botocore.credentials import ReadOnlyCredentials
from dask.delayed import Delayed, DelayedLeaf, delayed
from dask.distributed import Client

# Names exported by ``from <this module> import *``.
# ``get_total_available_memory`` and ``compute_memory_per_worker`` are defined
# in this file but are deliberately absent from this list.
__all__ = [
    "compute_tasks",
    "partition_map",
    "pmap",
    "save_blob_to_file",
    "save_blob_to_s3",
    "start_local_dask",
]

# Module-level logger named after this module.
# NOTE(review): not referenced anywhere in the visible part of this file.
_LOG: logging.Logger = logging.getLogger(__name__)


def get_total_available_memory(check_jupyter_hub: bool = True) -> int:
    """Report how many bytes of memory this process may use.

    Sources, in order of preference:

    1. The ``MEM_LIMIT`` environment variable (set by jupyterhub), consulted
       only when ``check_jupyter_hub`` is true. Its value is converted with
       ``int(..)``, so a non-integer value raises ``ValueError``.
    2. Total physical memory as reported by ``psutil.virtual_memory()``.

    :param check_jupyter_hub: when false, skip the ``MEM_LIMIT`` lookup and go
                              straight to the hardware information
    """
    hub_limit = os.environ.get("MEM_LIMIT", None) if check_jupyter_hub else None

    if hub_limit is None:
        # psutil is imported lazily: it is only needed on this fallback path
        from psutil import virtual_memory

        return virtual_memory().total

    return int(hub_limit)


def compute_memory_per_worker(
    n_workers: int = 1,
    mem_safety_margin: str | int | None = None,
    memory_limit: str | int | None = None,
) -> int:
    """Work out the number of bytes each worker should be limited to.

    The result is ``(budget - margin) // n_workers`` where

    * ``budget`` is ``memory_limit`` when supplied, otherwise the value of
      ``get_total_available_memory()``;
    * ``margin`` is ``mem_safety_margin`` when supplied. When it is omitted it
      is ``0`` if ``memory_limit`` was given, and otherwise the smaller of
      500Mb and half of the detected memory.

    Note that an explicitly supplied ``mem_safety_margin`` is subtracted even
    when ``memory_limit`` is given as well.

    String arguments (for example ``"4GiB"``) are converted to bytes with
    ``dask.utils.parse_bytes``.

    The result can be passed into the ``memory_limit=`` parameter of dask
    worker/cluster/client.

    :param n_workers: number of workers the budget is split between
    :param mem_safety_margin: bytes to hold back from the budget
    :param memory_limit: total bytes available to all workers together
    """
    from dask.utils import parse_bytes

    budget = parse_bytes(memory_limit) if isinstance(memory_limit, str) else memory_limit
    margin = (
        parse_bytes(mem_safety_margin)
        if isinstance(mem_safety_margin, str)
        else mem_safety_margin
    )

    if budget is not None:
        # An explicit limit was given: nothing is held back unless asked for
        if margin is None:
            margin = 0
    else:
        budget = get_total_available_memory()
        if margin is None:
            # leave 500Mb or half of all memory if RAM is less than 1 Gb
            margin = min(500 * (1024 * 1024), budget // 2)

    return (budget - margin) // n_workers


def start_local_dask(
    n_workers: int = 1,
    threads_per_worker: int | None = None,
    mem_safety_margin: str | int | None = None,
    memory_limit: str | int | None = None,
    **kw,
) -> Client:
    """
    Create a ``distributed.Client(..)`` with a computed per-worker memory limit.

    Before the client is created, ``distributed.dashboard.link`` is redirected
    through the jupyterhub proxy when both of these hold:

    * the setting still has the value ``{scheme}://{host}:{port}/status``
    * the ``JUPYTERHUB_SERVICE_PREFIX`` environment variable is set

    :param n_workers: number of worker processes to launch
    :param threads_per_worker: number of threads per worker, passed to ``Client`` as is
    :param memory_limit: maximum memory to use across all workers
    :param mem_safety_margin: bytes to reserve for the rest of the system
    :param kw: any other keyword arguments are forwarded to ``Client``

    .. note::

        ``memory_limit=`` and ``mem_safety_margin=`` are combined by
        ``compute_memory_per_worker``, and the per-worker figure it returns is
        what ``Client`` receives as ``memory_limit``.
    """
    link_key = "distributed.dashboard.link"

    # Leave the link alone if somebody has already customised it
    if dask.config.get(link_key) == "{scheme}://{host}:{port}/status":
        hub_prefix = os.environ.get("JUPYTERHUB_SERVICE_PREFIX")
        if hub_prefix is not None:
            hub_prefix = hub_prefix.rstrip("/")
            # ``{{port}}`` stays a literal ``{port}`` placeholder in the stored template
            dask.config.set({link_key: f"{hub_prefix}/proxy/{{port}}/status"})

    per_worker_bytes = compute_memory_per_worker(
        n_workers=n_workers,
        memory_limit=memory_limit,
        mem_safety_margin=mem_safety_margin,
    )

    return Client(
        n_workers=n_workers,
        threads_per_worker=threads_per_worker,
        memory_limit=per_worker_bytes,
        **kw,
    )


def _randomize(prefix: str) -> str:
    """Return ``prefix`` followed by ``-`` and 8 random lowercase hex digits."""
    return "{}-{:08x}".format(prefix, randint(0, 0xFFFFFFFF))


def partition_map(
    n: int, func: Any, its: Iterable[Any], name: str = "compute"
) -> Iterable[Any]:
    """
    Parallel map in lumps.

    The input is cut into lumps of at most ``n`` elements, and for every lump a dask delayed
    computation is produced that evaluates to the list of ``func`` results for that lump:

    .. code-block:: python

        [func(x) for x in its[0:n]],
        [func(x) for x in its[n:2*n]],
        ...
        [func(x) for x in its[..]],

    This is useful when you need to process a large number of small (quick) tasks (pixel drill for example).

    :param n: number of elements to process in one go
    :param func: Function to apply (non-dask)
    :param its:  Values to feed to ``func``
    :param name: How the computation should be named in dask visualizations

    Returns
    -------

    Iterator of ``dask.Delayed`` objects, produced lazily as ``its`` is consumed.
    """

    def lump_proc(items):
        return [func(item) for item in items]

    apply_to_lump = delayed(lump_proc, nout=1, pure=True)

    # Both prefixes get their own random suffix; the lump index is appended to each below
    input_prefix = _randomize("data_" + name)
    task_prefix = _randomize(name)

    for idx, chunk in enumerate(toolz.partition_all(n, its)):
        # traverse=False: the chunk is wrapped as one opaque value
        chunk_node = delayed(
            chunk, pure=True, traverse=False, name=f"{input_prefix}{idx}"
        )
        yield apply_to_lump(chunk_node, dask_key_name=f"{task_prefix}{idx}")


def compute_tasks(
    tasks: Iterable[Any], client: Client, max_in_flight: int = 3
) -> Iterable[Any]:
    """Parallel compute stream with back pressure.

    Produces the same values as:

    .. code-block:: python

        (client.compute(task).result() for task in tasks)

    but with up to ``max_in_flight`` tasks being processed at the same time.
    Results come out in the order the tasks went in, so a slow task at the
    head of the line delays everything behind it.

    :param tasks: dask tasks to compute
    :param client: dask client the tasks are submitted to
    :param max_in_flight: upper bound on concurrently submitted tasks

    .. note::

          The lower limit is 3 concurrent tasks: the queue between the two
          threads always has at least one slot, and one more task sits on
          each side of it. Calling this function for a single active task is
          pointless, and supporting exactly 2 is not worth the complexity for
          now.
    """
    # Feeder thread:
    #    1. Take dask task from iterator
    #    2. Submit to client for processing
    #    3. Send it off to wrk_q
    #
    # Calling thread:
    #    1. Pull scheduled future from wrk_q
    #    2. Wait for result of the future
    #    3. yield result to calling code
    from .generic import it2q, qmap

    def wait_for(future):
        return future.result()

    # (max_in_flight - 2) -- one on each side of queue
    queue_slots = max(1, max_in_flight - 2)
    wrk_q: queue.Queue = queue.Queue(maxsize=queue_slots)

    # fifo_timeout='0ms' ensures that priority of later tasks is lower
    submitted = (client.compute(task, fifo_timeout="0ms") for task in tasks)

    feeder = threading.Thread(target=it2q, args=(submitted, wrk_q))
    feeder.start()

    yield from qmap(wait_for, wrk_q)

    # NOTE(review): only reached when the stream is consumed to the end; if the
    # caller stops early or a future raises, the feeder is never joined -- confirm
    # whether it can stay blocked on the queue in that case.
    feeder.join()


def pmap(
    func: Any,
    its: Iterable[Any],
    client: Client,
    lump: int = 1,
    max_in_flight: int = 3,
    name: str = "compute",
) -> Iterable[Any]:
    """Parallel map with back pressure.

    Yields the same values, in the same order, as:

    .. code-block:: python

       (func(x) for x in its)

    except that ``func(x)`` runs concurrently on dask cluster.

    :param func:   Method that will be applied concurrently to data from ``its``
    :param its:    Iterator of input values
    :param client: Connected dask client
    :param lump:   Group this many datasets into one task
    :param max_in_flight: Concurrency limit; it is integer-divided by ``lump``
                          before being passed to ``compute_tasks``
    :param name:   Dask name for computation
    """
    lumps_in_flight = max_in_flight // lump

    lumped_tasks = partition_map(lump, func, its, name=name)
    lump_results = compute_tasks(
        lumped_tasks, client=client, max_in_flight=lumps_in_flight
    )

    # Each computed lump is a list of per-item results: flatten them back out
    for batch in lump_results:
        yield from batch


def _save_blob_to_file(
    data: bytes | str, fname: str, with_deps=None
) -> tuple[str, bool]:
    if isinstance(data, str):
        data = data.encode("utf8")

    try:
        with open(fname, "wb") as f:
            f.write(data)
    except OSError:
        return fname, False

    return fname, True


def _save_blob_to_s3(
    data: bytes | str,
    url: str,
    profile: str | None = None,
    creds: ReadOnlyCredentials | None = None,
    region_name: str | None = None,
    with_deps=None,
    **kw,
) -> tuple[str, bool]:
    """Upload ``data`` to ``url`` and report whether it worked.

    An S3 client is obtained from ``s3_client`` (with ``cache=True``) and the
    upload itself is delegated to ``s3_dump``.

    :param data: blob to upload
    :param url: destination url
    :param profile: passed to ``s3_client``
    :param creds: passed to ``s3_client``
    :param region_name: passed to ``s3_client``
    :param with_deps: not used by the function body
    :param kw: forwarded to ``s3_dump``

    Returns
    -------
    ``(url, result)`` where ``result`` is whatever ``s3_dump`` returned, or
    ``False`` when client creation or the upload raised ``OSError``,
    ``BotoCoreError`` or ``ClientError``. Other exceptions propagate.
    """
    # Both exception classes come from botocore.exceptions, their documented
    # home, instead of relying on botocore.errorfactory re-exporting ClientError.
    from botocore.exceptions import BotoCoreError, ClientError

    from .aws import s3_client, s3_dump

    try:
        s3 = s3_client(
            profile=profile, creds=creds, region_name=region_name, cache=True
        )

        result = s3_dump(data, url, s3=s3, **kw)
    except (OSError, BotoCoreError, ClientError):
        result = False

    return url, result


# Dask ``delayed`` wrappers around the two private savers; the public
# ``save_blob_to_file`` / ``save_blob_to_s3`` functions call these.
# Both are declared ``pure=False`` -- presumably because saving is a side
# effect that must not be de-duplicated by dask.
_save_blob_to_file_delayed = delayed(
    _save_blob_to_file, name="save-to-disk", pure=False
)
_save_blob_to_s3_delayed = delayed(_save_blob_to_s3, name="save-to-s3", pure=False)


def save_blob_to_file(
    data: bytes | str, fname: str, with_deps=None
) -> Delayed | DelayedLeaf:
    """
    Build a dask delayed operation that dumps an in-memory blob to the local filesystem.

    Nothing is written when this function is called; the write happens when
    the returned object is computed.

    :param data: Data blob to save to file (have to fit into memory all at once),
                 strings will be saved in UTF8 format.

    :param fname: Path to file

    :param with_deps: Useful for introducing dependencies into dask graph,
                      for example save yaml file after saving all tiff files.

    Returns
    -------
    Delayed object evaluating to

    ``(FilePath, True)`` tuple on success
    ``(FilePath, False)`` when the file could not be written


    .. note::

       Dask workers must be local or have network filesystem mounted in
       the same path as calling code.
    """
    save_task = _save_blob_to_file_delayed(data, fname, with_deps=with_deps)
    return save_task


def save_blob_to_s3(
    data: bytes | str,
    url: str,
    profile: str | None = None,
    creds=None,
    region_name: str | None = None,
    with_deps=None,
    **kw,
) -> Delayed | DelayedLeaf:
    """
    Build a dask delayed operation that dumps an in-memory blob to S3.

    Nothing is uploaded when this function is called; the upload happens when
    the returned object is computed.

    :param data:        Data blob to save to file (have to fit into memory all at once)
    :param url:         Url in a form s3://bucket/path/to/file
    :param profile:     Profile name to look up when setting up the S3 client
    :param creds:       Override credentials with supplied data
    :param region_name: Region name to use for the S3 client

    :param with_deps:   Useful for introducing dependencies into dask graph,
                        for example save yaml file after saving all tiff files.
    :param kw:          Extra keyword arguments for the upload, intended for
                        ``s3.put_object(..)`` options such as ContentType/ACL

    Returns
    -------
    Delayed object evaluating to

    ``(url, True)`` tuple on success
    ``(url, False)`` when the upload failed
    """
    client_options = dict(
        profile=profile,
        creds=creds,
        region_name=region_name,
        with_deps=with_deps,
    )
    return _save_blob_to_s3_delayed(data, url, **client_options, **kw)
